{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_10c import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imagenet(te) training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=1681)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 128\n",
    "# Training-time transforms: random resized crop to `size` plus a random flip.\n",
    "tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]\n",
    "\n",
    "bs = 64\n",
    "\n",
    "il = ImageList.from_files(path, tfms=tfms)\n",
    "# NOTE(review): presumably splits on the grandparent folder name ('val' = validation)\n",
    "# and labels by the parent folder name — confirm against grandparent_splitter / parent_labeler.\n",
    "sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))\n",
    "ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())\n",
    "\n",
    "# The validation set replaces the random crop/flip with a fixed centre crop.\n",
    "ll.valid.x.tfms = [make_rgb, CenterCrop(size), np_to_float]\n",
    "\n",
    "# c_in / c_out are stored on the databunch; cnn_learner reads data.c_in / data.c_out as defaults.\n",
    "data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## XResNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=1701)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def noop(x): return x\n",
    "\n",
    "class Flatten(nn.Module):\n",
    "    def forward(self, x): return x.view(x.size(0), -1)\n",
    "\n",
    "def conv(ni, nf, ks=3, stride=1, bias=False):\n",
    "    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# One ReLU module shared by `conv_layer` and `ResBlock.forward`; `inplace=True` makes it overwrite its input tensor.\n",
    "act_fn = nn.ReLU(inplace=True)\n",
    "\n",
    "def init_cnn(m):\n",
    "    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)\n",
    "    if isinstance(m, (nn.Conv2d,nn.Linear)): nn.init.kaiming_normal_(m.weight)\n",
    "    for l in m.children(): init_cnn(l)\n",
    "\n",
    "def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):\n",
    "    \"\"\"Return conv -> batchnorm (-> shared `act_fn` when `act`) as an `nn.Sequential`.\"\"\"\n",
    "    bn = nn.BatchNorm2d(nf)\n",
    "    # `zero_bn` starts the batchnorm scale at 0 instead of 1.\n",
    "    bn_scale = 0. if zero_bn else 1.\n",
    "    nn.init.constant_(bn.weight, bn_scale)\n",
    "    extra = [act_fn] if act else []\n",
    "    return nn.Sequential(conv(ni, nf, ks, stride=stride), bn, *extra)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ResBlock(nn.Module):\n",
    "    \"\"\"Residual block: a stack of conv layers added to a shortcut of the input.\n",
    "\n",
    "    `expansion == 1` builds two 3x3 conv layers; any other value builds a\n",
    "    1x1 -> 3x3 -> 1x1 stack. The block maps `ni*expansion` channels to\n",
    "    `nh*expansion` channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, expansion, ni, nh, stride=1):\n",
    "        super().__init__()\n",
    "        nf, ni = nh*expansion, ni*expansion\n",
    "        # The last conv layer of either stack has a zero-initialised batchnorm and no activation.\n",
    "        if expansion == 1:\n",
    "            layers = [\n",
    "                conv_layer(ni, nh, 3, stride=stride),\n",
    "                conv_layer(nh, nf, 3, zero_bn=True, act=False),\n",
    "            ]\n",
    "        else:\n",
    "            layers = [\n",
    "                conv_layer(ni, nh, 1),\n",
    "                conv_layer(nh, nh, 3, stride=stride),\n",
    "                conv_layer(nh, nf, 1, zero_bn=True, act=False),\n",
    "            ]\n",
    "        self.convs = nn.Sequential(*layers)\n",
    "        # Shortcut: 1x1 conv layer only when the channel count changes.\n",
    "        if ni == nf:\n",
    "            self.idconv = noop\n",
    "        else:\n",
    "            self.idconv = conv_layer(ni, nf, 1, act=False)\n",
    "        # Shortcut: average-pool only when the conv path is strided.\n",
    "        if stride == 1:\n",
    "            self.pool = noop\n",
    "        else:\n",
    "            self.pool = nn.AvgPool2d(2, ceil_mode=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        main = self.convs(x)\n",
    "        shortcut = self.idconv(self.pool(x))\n",
    "        return act_fn(main + shortcut)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class XResNet(nn.Sequential):\n",
    "    \"\"\"Sequential network: 3-conv stem, max-pool, stacked `ResBlock` stages, pooled linear head.\"\"\"\n",
    "\n",
    "    @classmethod\n",
    "    def create(cls, expansion, layers, c_in=3, c_out=1000):\n",
    "        \"\"\"Build and initialise a network with `layers[i]` blocks in stage `i`.\"\"\"\n",
    "        # Stem: three conv layers, only the first one strided.\n",
    "        stem_sizes = [c_in, (c_in+1)*8, 64, 64]\n",
    "        stem = []\n",
    "        for i in range(3):\n",
    "            stem.append(conv_layer(stem_sizes[i], stem_sizes[i+1], stride=2 if i==0 else 1))\n",
    "\n",
    "        # Stage widths before expansion; the first stage keeps stride 1, later ones use stride 2.\n",
    "        widths = [64//expansion, 64, 128, 256, 512]\n",
    "        stages = []\n",
    "        for i, n_blocks in enumerate(layers):\n",
    "            stages.append(cls._make_layer(expansion, widths[i], widths[i+1],\n",
    "                                          n_blocks=n_blocks, stride=1 if i==0 else 2))\n",
    "\n",
    "        head = [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(widths[-1]*expansion, c_out)]\n",
    "        res = cls(*stem, nn.MaxPool2d(kernel_size=3, stride=2, padding=1), *stages, *head)\n",
    "        init_cnn(res)\n",
    "        return res\n",
    "\n",
    "    @staticmethod\n",
    "    def _make_layer(expansion, ni, nf, n_blocks, stride):\n",
    "        \"\"\"Stack `n_blocks` ResBlocks; only the first one sees `ni` inputs and `stride`.\"\"\"\n",
    "        blocks = []\n",
    "        for i in range(n_blocks):\n",
    "            first = i == 0\n",
    "            blocks.append(ResBlock(expansion, ni if first else nf, nf, stride if first else 1))\n",
    "        return nn.Sequential(*blocks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def xresnet18(**kwargs):\n",
    "    \"\"\"XResNet with expansion 1 and [2, 2, 2, 2] blocks per stage.\"\"\"\n",
    "    return XResNet.create(1, [2, 2, 2, 2], **kwargs)\n",
    "\n",
    "def xresnet34(**kwargs):\n",
    "    \"\"\"XResNet with expansion 1 and [3, 4, 6, 3] blocks per stage.\"\"\"\n",
    "    return XResNet.create(1, [3, 4, 6, 3], **kwargs)\n",
    "\n",
    "def xresnet50(**kwargs):\n",
    "    \"\"\"XResNet with expansion 4 and [3, 4, 6, 3] blocks per stage.\"\"\"\n",
    "    return XResNet.create(4, [3, 4, 6, 3], **kwargs)\n",
    "\n",
    "def xresnet101(**kwargs):\n",
    "    \"\"\"XResNet with expansion 4 and [3, 4, 23, 3] blocks per stage.\"\"\"\n",
    "    return XResNet.create(4, [3, 4, 23, 3], **kwargs)\n",
    "\n",
    "def xresnet152(**kwargs):\n",
    "    \"\"\"XResNet with expansion 4 and [3, 8, 36, 3] blocks per stage.\"\"\"\n",
    "    return XResNet.create(4, [3, 8, 36, 3], **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=2515)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cbfs = [partial(AvgStatsCallback,accuracy), ProgressCallback, CudaCallback,\n",
    "        partial(BatchTransformXCallback, norm_imagenette),\n",
    "#         partial(MixUp, alpha=0.2)\n",
    "       ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = LabelSmoothingCrossEntropy()\n",
    "arch = partial(xresnet18, c_out=10)\n",
    "opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def get_batch(dl, learn):\n",
    "    \"\"\"Return one `(xb, yb)` batch from `dl` after `learn`'s callbacks have seen it.\n",
    "\n",
    "    The first batch of `dl` is stored on `learn.xb` / `learn.yb`, then\n",
    "    `learn.do_begin_fit(0)` is called and the `begin_batch` and `after_fit`\n",
    "    events are fired on `learn`. Whatever is in `learn.xb` / `learn.yb`\n",
    "    afterwards is returned.\n",
    "    \"\"\"\n",
    "    learn.xb,learn.yb = next(iter(dl))\n",
    "    learn.do_begin_fit(0)\n",
    "    # NOTE(review): presumably lets batch-level callbacks (e.g. CUDA move, normalisation)\n",
    "    # rewrite learn.xb / learn.yb — confirm against Learner.\n",
    "    learn('begin_batch')\n",
    "    # NOTE(review): presumably lets callbacks undo what do_begin_fit set up — confirm.\n",
    "    learn('after_fit')\n",
    "    return learn.xb,learn.yb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to replace the old `model_summary`, which took a `Runner`, with a version that works with a `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def model_summary(model, data, find_all=False, print_mod=False):\n",
    "    xb,yb = get_batch(data.valid_dl, learn)\n",
    "    mods = find_modules(model, is_lin_layer) if find_all else model.children()\n",
    "    f = lambda hook,mod,inp,out: print(f\"====\\n{mod}\\n\" if print_mod else \"\", out.shape)\n",
    "    with Hooks(mods, f) as hooks: learn.model(xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(arch(), data, loss_func, lr=1, cb_funcs=cbfs, opt_func=opt_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>valid_accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " torch.Size([128, 32, 64, 64])\n",
      " torch.Size([128, 64, 64, 64])\n",
      " torch.Size([128, 64, 64, 64])\n",
      " torch.Size([128, 64, 32, 32])\n",
      " torch.Size([128, 64, 32, 32])\n",
      " torch.Size([128, 128, 16, 16])\n",
      " torch.Size([128, 256, 8, 8])\n",
      " torch.Size([128, 512, 4, 4])\n",
      " torch.Size([128, 512, 1, 1])\n",
      " torch.Size([128, 512])\n",
      " torch.Size([128, 10])\n"
     ]
    }
   ],
   "source": [
    "learn.model = learn.model.cuda()\n",
    "model_summary(learn.model, data, print_mod=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arch = partial(xresnet34, c_out=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(arch(), data, loss_func, lr=1, cb_funcs=cbfs, opt_func=opt_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>valid_accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(1, cbs=[LR_Find(), Recorder()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEACAYAAACj0I2EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VNX5x/HPyU7IBiQhQED2RVEQI2pRccHlp9a9aF2rVaut2la72Nq99dfF1lbrhi1udas/930FVBaVAIKIIAEChC0J2ddZ7vn9MUlIICuZmcwdvu/XixfMnTtzn0OSJ88899xzjbUWERFxv5i+DkBERIJDCV1EJEoooYuIRAkldBGRKKGELiISJZTQRUSihBK6iEiUUEIXEYkSSugiIlFCCV1EJErEhfNgmZmZduTIkeE8pIiI6y1btqzUWpvV1X5hTegjR44kPz8/nIcUEXE9Y8zm7uynlouISJRQQhcRiRJK6CIiUUIJXUQkSiihi4hEiS4TujHmYWNMsTFmdattdxpj1hpjVhljXjTGZIQ2TBER6Up3KvRHgdP32vYuMNlaexjwFfCzIMclIhIVKuu8vPPFTkprGkN+rC4TurX2Q6Bsr23vWGt9TQ8/BnJDEJuIiOsVlNRw3X+WsXpbZciPFYwe+tXAm0F4HxGRqOP1OwAkxIb+lGWvjmCMuR3wAU92ss91xph8Y0x+SUlJbw4nIuI6Pr8FIC6SE7ox5krgLOBSa63taD9r7UPW2jxrbV5WVpdLEYiIRBWvE6jQ42NNyI+1X2u5GGNOB34KzLTW1gU3JBGR6OH1NSf0CKjQjTFPA0uACcaYImPMt4F7gVTgXWPMZ8aYB0Mcp4iIK/mcQAMjHAm9ywrdWvvNdjbPDUEsIiJRp/mkaFwYWi66UlREJIS8TSdF42MioOUiIiL7r7lCj49ThS4i4mq+5paLKnQREXdrbrlE/IVFIiLSOZ0UFRGJEs3TFpXQRURcztN8YZF66CIi7uZzHGJjDDExqtBFRFzN57dhWccFlNBFRELK43fC0m4BJXQRkZDy+W1YToiCErqISEh5/U5YFuYCJXQRkZDy+q0SuohINAhU6Gq5iIi4ns9xwnL7OVBCFxEJKbVcRESihFouIiJRwue3xIXhKlFQQhcRCSmPpi2KiEQHnxK6iEh08Dlay0VEJCp4fJq2KCISFVShi4hECa3lIiISJQLTFpXQRURcz+t3SIhTy0VExPW8fkcVuohINPBpLRcRkejg0VouIiLRwedE0C3ojDEPG2OKjTGrW20baIx51xizvunvAaENU0TEfRzH4nciq+XyKHD6XttuA9631o4D3m96LCIirXgdByByErq19kOgbK/N5wCPNf37MeDcIMclIuJ6Pr8FiPge+mBr7Q6Apr+zgxeSiEh08PoDFXrUTFs0xlxnjMk3xuSXlJSE+nAiIhHD65IKfZcxZghA09/FHe1orX3IWptnrc3Lysraz8OJiLhPc4UeMT30DrwCXNn07yuBl4MTjohI9GjuoUfM8rnGmKeBJcAEY0yRMebbwJ+AU4wx64FTmh6LiEgre2a5hKflEtfVDtbab3bw1MlBjkVEJKq4peUiIiJdaGm5xET2SVEREemCp7lCj1OFLiLiai0XFkXLPHQRkQOVzx/ek6JK6CIiIdLccomYaYsiIrJ/mlsuCUroIiLu1rKWi1ouIiLu5nXcsZaLiIh0wevThUUiIlHB5+ikqIhIVHDL8rkiItKFlrVcdGGRiIi77Vk+VxW6iIirebTaoohIdNhzk2gldBERV/M5DjEGYrV8roiIu3n8TtimLIISuohIyPj8lvgwVeeghC4iEjJevxO2m1uAErqISMh4/Za4MM1BByV0EZGQ8fodEsI0Bx2U0EVEQsank6IiItHB69iwreMCSugiIiHj9Tlhu6gIlNBFRELG59iwreMCSugiIiHj9atCFxGJ
Cl6/E7alc0EJXUQkZHx+S3ycWi4iIq7n9Tu6sEhEJBp4/S6atmiM+aEx5gtjzGpjzNPGmKRgBSYi4nauOSlqjBkG3AzkWWsnA7HAxcEKTETE7QLTFl2Q0JvEAf2MMXFAMrC99yGJiESHQIXugpaLtXYb8FdgC7ADqLTWvrP3fsaY64wx+caY/JKSkv2PVETEZVwzbdEYMwA4BxgFDAX6G2Mu23s/a+1D1to8a21eVlbW/kcqIuIyPr97rhSdBWyy1pZYa73AC8DXghOWiIj7edxyUpRAq+VoY0yyMcYAJwNfBicsERH387ll2qK19hPgOWA58HnTez0UpLhERFwv3NMW43rzYmvtr4FfBykWEZGoYa113bRFERFph8+xALoFnYiI23n9DoAqdBERt/P6AxV6XIwqdBERV2uu0BPiVKGLiLiar6VCV0IXEXG15grdFfPQRUSkY3sSuip0ERFXa5626Ja1XEREpAMenyp0EZGo0Fyhq4cuIuJy6qGLiESJlitFNW1RRMTdmuehq+UiIuJyarmIiESJlrVcVKGLiLhby1ouqtBFRNzN52j5XBGRqODVSVERkeigk6IiIlHCpxtciIhEh5YKXTe4EBFxt5Yeuq4UFRFxN59ucCEiEh2aWy6x6qGLiLib17HExxqMUUIXEXE1r88J65RFUEIXEQkJn2PDOmURlNBFRELC43dICOOURVBCFxEJCZ/fCevNLaCXCd0Yk2GMec4Ys9YY86Ux5phgBSYi4mY+vyU+Lrwtl7hevv5u4C1r7YXGmAQgOQgxiYi4nsfvhPWiIuhFQjfGpAHHA98CsNZ6AE9wwhIRcTef34b15hbQu5bLaKAEeMQYs8IY829jTP8gxSUi4mpev7umLcYB04AHrLWHA7XAbXvvZIy5zhiTb4zJLykp6cXhRETcw+vYsN7cAnqX0IuAImvtJ02PnyOQ4Nuw1j5krc2z1uZlZWX14nAiIu7h8zskuKXlYq3dCWw1xkxo2nQysCYoUYmIuJy3D6Yt9naWy03Ak00zXDYCV/U+JBER9/P6LUnxLpq2aK39DMgLUiwiIlHD63dIcFEPXUTkgFVQXM0f3/yS8tr2Z2u7bdqiiMgB66UV25nzwUbOvOcjlm0u3+d5r+OuaYsiIgessjoP/RNiiYuN4aI5S3joww04jm153m3z0EVEDlgVdR5y0pN47eZjOeXgwfzvG2t5eeW2lud9fi2fKyLiCmW1HgYkJ5CWFM99l0wjPtbw1a6alue9fod4LZ8rIhL5Kuq8DOifAEBMjCErJZHiqsaW571+S7wqdBGRyBeo0ONbHmelJlJc3dDyWD10EREXsNa2qdABslKTKKneU6EHpi0qoYuIRLQ6jx+P32FA8p6Enp2W2JLQrbV4nfCv5dLbS/9FRA44ZU0XEw1sldCzUhLZXevB63cwgLWoQhcRiXQVdV4AMlr10LPTEgHYXePB1zQfXVeKiohEuLK6pgq9VQ89OzUJgOLqBjx+B0BruYiIRLqKpoSe0brlkhqo0IurGvH5myp0TVsUEYlszT301tMWs5sSeklNI76mCl0XFomIRLjyOi/GQHq/PQk9M2VPhd7ccokP8w0ulNBFRHqoos5DWlJ8m1ksCXExDOyfQHF1w56Wi06KiohEtrJaT5sTos2yUhIprm7E21yh66SoiEhkq6jztpmy2Kz54iJvU4UerwpdRCSyldV62lxU1CwrtTmhq0IXEXGFijpPmymLzfZO6LpSVEQkwpXVeRjYv52WS2oSHr9DaU1gWqOWzxURiWANXj8NXqfdCr15Lvr2inpA89BFRCJaeV3zRUXtt1wAtjUldF0pKiISwVpWWmy35bJXha4euohI5Nqz0mI7LZe0wAJd2ysDdy5SQhcRiWB7KvR9E3pKYhzJCbFsK2+u0NVyERGJWHtWWty35QKBPnppTeDORarQRUQiWFltoOXS3klR2NNHB63lIiIS0crrPKQmxnVYfTff6AJUoYuIRLSK
Og8D2umfN8tqVaG7bvlcY0ysMWaFMea1YAQkIhLJyuq8bW5ssbc2CT3OfS2X7wNfBuF9REQiXlcVepseupsqdGNMLnAm8O/ghCMiEtnKaj0dnhCFvSp0l50U/QfwE8AJQiwiIhGvo7XQmzWfFI2LMRjjkoRujDkLKLbWLutiv+uMMfnGmPySkpL9PZyISJ/z+BxqGn3troXeLDstUKGHe8oi9K5CnwGcbYwpBJ4BTjLGPLH3Ttbah6y1edbavKysrF4cTkSkb7VcVNRJD31gcgKxMSbsUxahFwndWvsza22utXYkcDEwz1p7WdAiExGJMGVNCb2zCj0mxpCZkuCuhC4icqApb7lKtOMeOgT66OFeOheClNCttQustWcF471ERCJVc8uls2mLEJi62BcVelzYjygi4lJlndzcorUzDh3CuMGp4QipDSV0EZFu2rMWeuctlwuOyA1HOPtQD11EpJvKaj0kJ8SSFB/b16G0SwldRKSbyus6v0q0rymhi4h0U3mthwHt3Es0Uiihi4h0U3mdVxW6iIjb1Xl87KxsUEIXEXGzFVvKOfOeheyqbuDEiZG7hImmLYqIdMDvWO5+fz33zS8gJy2Jp645mmPGDOrrsDqkhC4i0oG3Vu/knvfXc97hw/jtOYeQlhS5J0RBCV1EpENLNpaSkhjHnRceRlwfXMrfU5EfoYhIH8kvLOfwERmuSOaghC4i0q7KOi/rdlUzfeTAvg6l25TQRUTasWxLGdZCnhK6iIi7LS0sJz7WMHV4Rl+H0m1K6CIi7Vi6qYzJw9LplxCZC3G1RwldRDrl9TvkF5Zx3/wClm8p7+twwqLB62dVUSVHuqjdApq2KCKt/O8bX/Layu2kJyeQ0S+emBhYsaWCOo8fgGEZ/Xj/1pn7LB/7ycbdZCQnMH5wCsaE/9ZrwbaqqBKP31FCFxF32lpWx9yFm5g8LJ2slEQq6z00NDpcMC2XGWMHYS3c8ORy/rNkM9ceP7rldR9+VcIVD38KwND0JGZOyOKiI0e4qve8t6WFZQAccdCAPo6kZ5TQRVxkVVEFJdWNnDxpcNDfe+7CTcQYePCyaQxJ79fuPjPHZ/HPeev5Rl4uGckJVNZ7+enzqxibncK3jx3FB+tKeHXlDl5duYNPbz+Z5IS2KeY7/8nHYHjgsmkRXcnnF5YxNjuFgV3cOzTSqIcu4iK/eGk1Nz61gppGX1Dft6zWwzNLt3DO1GEdJnOA2/5nItWNPu5fsAGA3726huLqRv72jSl8c/oIHrz8COZemUdNo4+3v9jZ5rWFpbW8/cUu3vpiJ/PWFgc1/mDyO5b8zeWua7dAFCR0ay1+x/Z1GCIht7WsjlVFldR7/byxakdQ3/vRxYU0eB2unzm60/0mDUnjgmm5PLqokEcXbeL55UXcMHMMU1q1V44cOZARA5N5bllRm9c+m7+VGAPDB/bj96+tweNzgjqGvS0qKOXyuZ9QVuvp0eu+2lVNdYOPI0e6q90CLk7o1lpeW7WdE/+6gHG3v8Exf3yfCx5YzC3PfsbnRZVBO86m0lpWFVXg9OKXxvaKetbvqsba4P/iWbuzinPuW8Tlcz8JetUmkeX1zwNJPDs1cZ9k2Rt1Hh+PLynklIMHMza76zvV33LKeIyB37y6hok5qdx88rg2z8fEGM6fNozFG3ZTVF4HgM/v8NyyIk6ckM3vz5lM4e46Hlm0qd33t9Yy54MNfOuRT6lu8O7XmDw+h9tf/JyP1pfy21e/6NFr85v6526s0F3ZQ/94427++MaXrCyqZMLgVG44YQw7KxvZUVnPe2t28cLybZx+SA63njqecYNTcRxLVYOXijovu2s9lNd6qPX4OGliNqkdrJ7m9TvcN7+Af84rwO9YMlMSmDk+m+PHZzIxJ42Rmckkxu0502+txeu3xMcajDE0+vy8u2YX/126lYUFpVgLWamJzBgziKnDMyit8VC4u5bNu+to8PqJj40hPi6GtKQ4Dh6axpTcDKYMz2BY
Rvsff/2O5aEPN/L3d7+if2IsVQ0+rpj7CY9ePT3iV4ST/fP6qh1MyU3ntMk5/OWtdRSW1jIys3+v3/eZT7dSUefl+pljurX/0Ix+fOf40cz5cCN3zZ5KQty+deEF03L5x3vreXH5Nm46eRwL1pVQXN3IRUcO54QJ2Zw8MZt/zivgvGnDyE5Nanmd1+/wixdX89/8rQD84JnPeOiKPGJjetZvf+qTzRTuruO4cZm8/Nl2zjpsKKcc3P55B79jmbe2mPXF1Wwtq2NhQSk5aUnkDui49RSpTCiqxo7k5eXZ/Pz8/X59vcfPHW+s4YmPtzAkPYlbThnP+dNy23yxqxu8zF24iX9/tIk6j4+0fvFU1ntpb5iHj8jgqWuO3ufCgY0lNfzw2ZWs3FrBeYcP4/jxmSxYV8IHX5VQUReoGGJjDCMGJmNMYM2HinpvS+snoWkhH4/fYWh6Et/IG87QjCQWb9jNooJSSms8xMYYcgf046BB/UlJjMXjc/D4LWW1jazbWY3XH3ivUw4ezB3nTW7zTb96WyW/fHk1K7ZU8D+Tc/jDuZNZWljOTU8vZ2JOGo9fPZ0BLjuZE+nqPL7AL90+WqRpy+46jr9zPj8/YyJnTxnG1/70Pt87cSy3njqhx+/1wIINPL6kkIH9ExiclsTKrRWMyU7h2e8c0+33sNZSVe8jPbnj4uHih5aws7KB+T86gWsfX8bKogoW33YS8bExbCqt5dS/f8DXpwzl12cdQmpSHLUeH999cjkfrS/lppPGkpmSyK9f+YLvnjCGn5w+sduxVdZ7OeHO+UwaksajV03n7HsXUlbr4d1bZpLer228K7aU88uXV7N6WxUAg/onMHxgMt/Iy+XSow7q9jFDzRizzFqb19V+rqnQ12yv4uZnVlBQXMO1x43i1lMn7DMXFiA1KZ4fzBrPFceM5NHFhYGbuibHk5GcQEZyPAP6JzAwOYH1xTX8+LmVfP+ZFTxw2RHExhistTz96VZ+/9oaEuNjuO+SaZx52BAAzjs8F5/fYd2uagqKaygormFDSQ0GQ0ZyPBnJ8fSLj8Xjt3h8Do61zBibybFjM1t+4Vx05AisteyqamRQSkKHyaHR52ftjmoWrCvhvgUFnPb3D/ndOZM5cuRA7nx7HS+sKGJAcgJ3XzyVs6cMxRjD6ZNzmHP5EVz/xHK+MWcJs/NyyRs5kMlD09tUUNbaiJ5dEIlKaxo5656FTBmezpzLu/yZConmdssZhw4hJz2J48Zl8fyyIn44azwxPaheX/5sG39+ay3TRw4kJSmOXVUNJMXHcssp43sUjzGm02QOgSr9x8+t4s3VO5m/rphrjxvd8j0/KrM/V88YxZwPN/LC8m3EGIiPjcHvWP5ywWHMPnI41lrW7qzi/gUbmDgkjbOnDG3z/n7H8t+lW3lt1XZm5w3nnKmBn4X7FxRQUe/l52dMIiEuhjsvnMK59y/ijtfX8JcLp+DxOWwsreHRRYU8s3Qrg9MSufviqcyaNJj+ia5Jie1yRYX++JJC/vDal2Qkx/O32VM4blxwbgH1yKJN/PbVNVxxzEH8YNZ4bnt+Fe+s2cWxYzP52+wpDE5L6vpNQqyguIZb/y/waSE+1mAwXHXsSL534th2WysL15fyq5dXs7G0FoDEuBhSk+Jo8DrUe/0kx8dy/PgsTp6UzYkTsrus5K21lNZ4SE2Ka/cX6N78jmXxhlL+L7+I+euK+fXXD+HCI3L3b/ARwHEsVz+2lAXrSgB49cZjOTQ3PWjv3d1kfNY/PyI2JoaXvzcjEMfK7dz09AqevOYoZozNpMHr5/4FG/i8qIJ6r596r0NSXAw3njS25efls60VzJ6zhKm5GTxxzVHttkqCqbbRx5F3vEdsjKG6wcf8H53AqFYtIq/f4e0vdrKrqpHKOg9VDT5On5zD0aP33BHI43O49N8f8/m2Sq47fgx5Bw1g6ogMVm+r5HevrmHtzmoG9k+grNbDceMyueGEMXzrkaWcddgQ7po9
teV9/vzWWh5YsIFx2SkU7q7F67fExhiunjGS788aT0qEJ/LuVuiuSOhPfLyZBetK+MuFhwV9Xugdr6/hXx9tIjUxjkafw09On8DVM0b1qOoJNZ/fYe7CTRQU13DzyeMYPjC5y9cUVzewrLCcZZvLqff6SYqPJTEuht01HuatK6akupEYA2ceNpSbTxrLuMGBk2HWWpYWlvPm6h18uaOKdTurm+50Hs9VM0Zx5TEj263MfH6Hx5ds5l8fbWRHZQPp/eLJTk1kU2ktj397Ol8bkxn0/5dw+PdHG/nD61/y49MmMOeDDRw1ehD/uqL3Vfpbq3dwy7MreeXGYxmbndLpvoWltZzw1wXcfsaklgt6Grx+jrzjPWZNGszFRw7np8+vonB3HQcPSSMlMY6khFg2ltRQVF7PaYcM5rrjR3PDE8tJiAv8UhiUktjrMXTHrc+u5PnlRRw1aiD/7UFLp7XSmkZueGIZyzaX41gwBqwNXLX68zMmcfrkHJ74eDN/eWsttR4/iXExzP/RCQxtdf6pwevnpqdX4HcsE3JSmZiTyrQRA7r1sxQJoiqhN8cYilaB41h+8vwq1myv4s5vHMYhQ4NTfUUyx7F8vq2S1z/fwZMfb6bO6+eMyUMYm53Ciyu2saWsjqT4GCYNSWNiTipjs1NZsqGU974spn9CLBcdOYJZk7KZdtAAkuJjyS8s4xcvrWbtzmqOGT2Iy44+iJMnZdPoc7jwgcUUVzfywne/xpiszhNXpPm8qJLzH1jEiROymXP5Edz9/nr+8d563rj5OA4emgYEztn89tU1fH3KUGaO794nxzqPj5P/9gE7Khu45KgR/O95h3a6/33zC7jz7XUs/OmJ5A7Yk4Buf/Fz/rt0Kz7HMnxgP/58/mF8beyeX5wNXj9zF27i3nkF1Hv99E+I5YXvzmBCTtczWYLl001lzJ6zhLsvnso5U4f16r1qGn18tqWC/M1lpCbFc+lRI9p8atxRWc+db6/j8BEDuPzoyOl/B0NUJXQJnbJaD3MXbuSxxZup9fj42phBXDAtl9Mn5+xzld+XO6p4YMEGXv98B37HkhAXw8ScVFYVVTIkPYlff/1gTjskp80v3q1ldZx73yJSkuJ48bszOv2E9emmMu55fz0bSmr46ekTW3qiEPgE8MzSrXy2tYLMlESyUhMZmp7EiROzu9UK6qnyWg/n3b+IRp/Dm98/LnBVZJ2XGX+ex/HjM7n/0iOo8/i48uFPWVpYTv+EWF6+cUabaX9ev8M7X+xi5oSsNh/p73pnHffMK2Dq8AzW7qxiyW0nd9r6OvOej0iIi+HF785os33N9ipmz1nC7Lzh/Oi08ft8vZrtqKznwQUbOPWQHGaMDf8npYLiGsZk9de5m14IeUI3xgwHHgdyAAd4yFp7d2evUUKPXNUNXuq9/jazaTrbd2lhGYsLdrNsSznTRw3k5pPGdXhCadnmMr75r08YmJzAaYcMZtbBg5k+aiC1jX62V9SzpayOx5cU8vHGMjJTEshJT2L1tqqWGTxfbK/i96+tYX1xDZkpgcvNm2cB5aQl8YNZ47jwiNyg3SZs9bZKrn9iGcVVjTxxzVFMH7VnPvJf317HvfMLeOXGGfzpzbV8vHE3vzrrYO6dX0BaUjwv3TiDtKR4qhq8fK9pxsaU3HQeu3o6GckJbC2rY9ZdH3DaITl878SxnPaPD/nxaRP43olj243lo/UlXD73U3511sFcfeyofZ7XSe4DQzgS+hBgiLV2uTEmFVgGnGutXdPRa5TQD1yLC0p5eFEhCwtKaPA6LX3QZlmpiVw/cwyXTB9BQlxMyxz72BhDvdfPiIHJ3H7mJE5tmktcUedl9fZK7nr3K1ZsqWB0Vn9+f87kXlegzy0r4vYXP2dAcgIPXDaNw0e0vVqwvNbDjD/Pw7GWRp/DXbOncN7huXy6qYxL/vUxM8dn8ZuzD+Gax/LZUFLDFceM5ImPNzMqsz//uWY6v3nlC+atLWberYEe7+Vz
P+GrXdV89JOT9jlJWdvo47R/fEhCbAxvfP+4kHwSEXcIe8vFGPMycK+19t2O9lFCl3qPn8UbSlm+pZxB/RMZmpHEkPR+TMhJ3Sdhrd1ZxZ/fXMtRowdx1YyRbS7kamat5d01u/jTm2vZVlHPU9cexREHtX+FX6PPT3WDD8daMvsntpz4rqz3Mm/tLl5ftYP3vizm6NEDufeSaWR2cOLwL2+t5f4FG/jT+Ydy8fQRLdsfX1LIr17+gsS4GBJiY3jw8iOYMTaThetLufbxfNL7xbOzqoFbThnfcnXl/HXFXPXIUv5x0VTOPbxtj/l3r67h4UWbePY7x7T5lCAHnrAmdGPMSOBDYLK1tqqj/ZTQJVTKaj1c8MBiyus8PH/DnhOwyzaX88uXVlNQXIPHv2ftkPhYQ056EgOSE1izvQqfY8lOTeTi6SO4+aSxnbZv/I5lS1ldmyl4EPjl8suXV7OoYDdzLj+C8YP39NPzC8u46pGlpPWLb7OeuONYZv39A/onxPHKjTNa2icrtpRz/gOLufSoEfzh3M5Pmkr0C1tCN8akAB8Ad1hrX2jn+euA6wBGjBhxxObNm3t1PJGObN5dy/n3LyY5MZanrz2a/zRNoxyS3o+zpgwhLSme1KRAn39HZQPbK+opqW7k0Nx0Tjskh6m5GUGZrtpRX3tHZT2GwC+S1p74eDO/eGk1d154GFOGZ5CWFM+VD39KVYOXd354fIfLU8iBIywJ3RgTD7wGvG2tvaur/VWhS6h9trWCix9ags9v8TmWS44awc/PmBTRF47Ue/zMvHM+xdWNbbY//K08TpoY/HXPxX3CcVLUAI8BZdbaH3TnNUroEg7z1xZz7/wCfjBrXNCuKg61sloPa3dUUVYXWDwuJ71fh4tJyYEnHAn9WOAj4HMC0xYBfm6tfaOj1yihi4j0XMgX57LWLgQ0AVZEJEK49gYXIiLSlhK6iEiUUEIXEYkSSugiIlFCCV1EJEoooYuIRAkldBGRKBHWG1wYYyqB9a02pQOV7Txuvb3535lA6X4eeu/j9OT5jmJs73FX/+6LMbS3/UAcQ+tt+zuGruLvbB+Noe0YQvl91Nk+XY2hO+PpizEcZK3t+rJna23Y/hC4CUaXj1tvb7UtP1jH7cmO64+fAAAD0klEQVTz3Y25O//uizG0t/1AHMNe2/ZrDF3FrzF0fwyh/D7qzRi6M56+HkNnf8Ldcnm1m49f7WSfYBy3J893N+bu/nt/7e8Y2tt+II4hHPF3to/G0P3jdyWUY+jOePp6DB0Ka8ulN4wx+bYbaxlEMo0hMmgMfc/t8UNkjsFNJ0Uf6usAgkBjiAwaQ99ze/wQgWNwTYUuIiKdc1OFLiIinVBCFxGJEkroIiJRIioSujEmxhhzhzHmn8aYK/s6nv1hjDnBGPORMeZBY8wJfR3P/jLG9DfGLDPGnNXXsfSUMWZS0///c8aYG/o6nv1hjDnXGPMvY8zLxphT+zqe/WGMGW2MmWuMea6vY+mJpu/9x5r+/y/tixj6PKEbYx42xhQbY1bvtf10Y8w6Y0yBMea2Lt7mHGAY4AWKQhVrR4I0BgvUAEm4dwwAPwWeDU2UHQtG/NbaL6211wOzgbBPRwvSGF6y1l4LfAu4KIThtitIY9horf12aCPtnh6O53zguab//7PDHiyE90rRDq6GOh6YBqxutS0W2ACMBhKAlcDBwKHAa3v9yQZuA77T9NrnXDqGmKbXDQaedOkYZgEXE0gmZ7kt/qbXnA0sBi5x49eg1ev+Bkxz+RjC/rPcy/H8DJjatM9TfRHvft9TNFistR8aY0butXk6UGCt3QhgjHkGOMda+0dgn4/yxpgiwNP00B+6aNsXjDG0Ug4khiLOzgTp63Ai0J/AN3e9MeYNa62z936hEKyvgbX2FeAVY8zrwFOhi7jdYwfja2CAPwFvWmuXhzbifQX5Z6HP9WQ8BD5Z5wKf0Ufdjz5P6B0YBmxt9bgIOKqT/V8A/mmMOQ74MJSB
9UCPxmCMOR84DcgA7g1taN3WozFYa28HMMZ8CygNVzLvRE+/BicQ+NicCLwR0si6r6c/CzcR+KSUbowZa619MJTBdVNPvw6DgDuAw40xP2tK/JGko/HcA9xrjDmT4CwP0GORmtBNO9s6vALKWlsHRETPrZWejuEFAr+YIkmPxtCyg7WPBj+U/dLTr8ECYEGogtlPPR3DPQQSSyTp6Rh2A9eHLpxea3c81tpa4KpwB9Nan58U7UARMLzV41xgex/Fsr80hr7n9vhBY4hEETueSE3oS4FxxphRxpgEAifaXunjmHpKY+h7bo8fNIZIFLnjiYCzyE8DO9gz5fDbTdvPAL4icDb59r6OU2OI7DG4PX6NITL/uG08WpxLRCRKRGrLRUREekgJXUQkSiihi4hECSV0EZEooYQuIhIllNBFRKKEErqISJRQQhcRiRJK6CIiUeL/AbivS8y5wKgIAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_phases(phases):\n",
    "    \"\"\"Return `phases` as a list with the remainder `1 - sum(phases)` appended.\"\"\"\n",
    "    fractions = listify(phases)\n",
    "    remainder = 1 - sum(fractions)\n",
    "    return fractions + [remainder]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.3, 0.7]\n",
      "[0.3, 0.2, 0.5]\n"
     ]
    }
   ],
   "source": [
    "print(create_phases(0.3))\n",
    "print(create_phases([0.3,0.2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-2\n",
    "pct_start = 0.5\n",
    "# create_phases(0.5) appends the remainder 1-0.5, i.e. two phases of equal length.\n",
    "phases = create_phases(pct_start)\n",
    "# NOTE(review): presumably lr goes lr/10 -> lr -> lr/1e5 and momentum 0.95 -> 0.85 -> 0.95\n",
    "# with cosine annealing in each phase — confirm against cos_1cycle_anneal.\n",
    "sched_lr  = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))\n",
    "sched_mom = combine_scheds(phases, cos_1cycle_anneal(0.95, 0.85, 0.95))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cbsched = [\n",
    "    ParamScheduler('lr', sched_lr),\n",
    "    ParamScheduler('mom', sched_mom)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(arch(), data, loss_func, lr=lr, cb_funcs=cbfs, opt_func=opt_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>valid_accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.701341</td>\n",
       "      <td>0.486505</td>\n",
       "      <td>1.767135</td>\n",
       "      <td>0.510000</td>\n",
       "      <td>00:38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.488563</td>\n",
       "      <td>0.590507</td>\n",
       "      <td>1.830278</td>\n",
       "      <td>0.514000</td>\n",
       "      <td>00:38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.345791</td>\n",
       "      <td>0.651078</td>\n",
       "      <td>1.440738</td>\n",
       "      <td>0.638000</td>\n",
       "      <td>00:38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.174334</td>\n",
       "      <td>0.727393</td>\n",
       "      <td>1.005328</td>\n",
       "      <td>0.792000</td>\n",
       "      <td>00:38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.019539</td>\n",
       "      <td>0.790213</td>\n",
       "      <td>0.912079</td>\n",
       "      <td>0.824000</td>\n",
       "      <td>00:39</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(5, cbs=cbsched)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## cnn_learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=2711)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def cnn_learner(arch, data, loss_func, opt_func, c_in=None, c_out=None,\n",
    "                lr=1e-2, cuda=True, norm=None, progress=True, mixup=0, xtra_cb=None, **kwargs):\n",
    "    \"\"\"Build a `Learner` around `arch(...)` with a standard set of callbacks.\n",
    "\n",
    "    Accuracy stats and `xtra_cb` always come first; the progress, CUDA,\n",
    "    batch-transform (`norm`) and MixUp (`mixup`) callbacks are appended, in\n",
    "    that order, when their argument is truthy. Falsy `c_in` / `c_out` fall\n",
    "    back to `data.c_in` / `data.c_out`, and only truthy values are forwarded\n",
    "    to `arch`. Remaining `kwargs` go to `Learner`.\n",
    "    \"\"\"\n",
    "    cbfs = [partial(AvgStatsCallback,accuracy)]+listify(xtra_cb)\n",
    "    if progress:\n",
    "        cbfs.append(ProgressCallback)\n",
    "    if cuda:\n",
    "        cbfs.append(CudaCallback)\n",
    "    if norm:\n",
    "        cbfs.append(partial(BatchTransformXCallback, norm))\n",
    "    if mixup:\n",
    "        cbfs.append(partial(MixUp, mixup))\n",
    "    # Channel counts default to what the databunch advertises.\n",
    "    c_in = c_in or data.c_in\n",
    "    c_out = c_out or data.c_out\n",
    "    arch_args = {name: val for name, val in (('c_in', c_in), ('c_out', c_out)) if val}\n",
    "    model = arch(**arch_args)\n",
    "    return Learner(model, data, loss_func, opt_func=opt_func, lr=lr, cb_funcs=cbfs, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet34, data, loss_func, opt_func, norm=norm_imagenette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(5, cbsched)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imagenet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see all this put together in the fastai [imagenet training script](https://github.com/fastai/fastai/blob/master/examples/train_imagenet.py). It's the same as what we've seen so far, except it also handles multi-GPU training. So how well does this work?\n",
    "\n",
    "We trained for 60 epochs, and got an error of 5.9%, compared to the official PyTorch resnet which gets 7.5% error in 90 epochs! Our xresnet 50 training even surpasses standard resnet 152, which trains for 50% more epochs and has 3x as many layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 11_train_imagenette.ipynb to exp/nb_11.py\r\n"
     ]
    }
   ],
   "source": [
    "!./notebook2script.py 11_train_imagenette.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
